#!/usr/bin/env python3
"""
T3 helper — build true-size lens table from KiDS multi-band tiles.

Fixes:
- FITS endianness (Table.read → to_pandas)
- Robust mass semantics (auto-detect; NO double-log)
- Drop LePhare sentinels (e.g., -99) and implausible logs (<6 or >13)
- One-time tile scan with gzip cache (data/lens_sizes_raw.csv.gz)
- Ignores header/meta files (*.fits.wc, *.wc.fits, *.L1.fits; matched case-insensitively)

Output CSV:
  lens_id, ra_deg, dec_deg, z_lens, R_G_kpc, Mstar_log10, R_G_bin, Mstar_bin
"""

import os, glob
import numpy as np
import pandas as pd
from astropy.io import fits
from astropy.table import Table
from astropy.cosmology import FlatLambdaCDM

# Input locations (relative paths, resolved against the current working directory).
MULTIBAND_DIR = "data/kids_multiband"  # directory globbed for *.fits tile catalogues
BRIGHT = "data/KiDS_DR4_brightsample.fits"  # must provide ID, RAJ2000, DECJ2000, zphot_ANNz2
LEPH   = "data/KiDS_DR4_brightsample_LePhare.fits"  # source of the ID + stellar-mass columns
# Output locations.
CACHE  = "data/lens_sizes_raw.csv.gz"  # per-ID median size, written after a tile scan and reused afterwards
OUTCSV = "data/lenses_true.csv"  # final lens table

# Cosmology used by kpc_per_arcsec() for the arcsec -> kpc conversion.
COSMO = FlatLambdaCDM(H0=70, Om0=0.3)  # diagnostics-only
# Factor multiplied with FLUX_RADIUS to obtain arcsec
# (presumably the detector pixel scale in arcsec/pixel — TODO confirm).
PIX_ARCSEC = 0.213

# Bin edges handed to pd.cut(..., right=False): intervals are closed on the
# left, open on the right; values outside the outer edges are dropped.
RG_EDGES = [1.5, 3.0, 5.0, 8.0, 12.0]  # kpc (frozen)
MS_EDGES = [10.2, 10.5, 10.8, 11.1]    # log10(M/Msun) (frozen)


def _first_table_hdu_index(hdul):
    """Locate the first binary-table HDU in an opened FITS HDU list.

    Args:
        hdul: Iterable of HDU objects (e.g. the result of ``fits.open``).

    Returns:
        The zero-based position of the first ``fits.BinTableHDU`` instance,
        or ``None`` when the list contains no such HDU.
    """
    table_positions = (
        pos for pos, unit in enumerate(hdul) if isinstance(unit, fits.BinTableHDU)
    )
    return next(table_positions, None)


def fits_to_df(path: str) -> pd.DataFrame:
    """Read the first binary table of a FITS file into a native-endian DataFrame.

    Args:
        path: Path of the FITS file to read.

    Returns:
        A DataFrame built from the first ``BinTableHDU`` in the file, with
        every numeric column stored in the machine's native byte order.

    Raises:
        SystemExit: If the file contains no ``BinTableHDU``.
    """
    # First pass only discovers which HDU holds the table.
    with fits.open(path, memmap=True) as hdul:
        idx = _first_table_hdu_index(hdul)
    if idx is None:
        raise SystemExit(f"No BinTableHDU found in {path}")

    t = Table.read(path, format="fits", hdu=idx, memmap=True)
    df = t.to_pandas()

    # FITS stores data big-endian; pandas operations can fail on non-native
    # buffers, so convert any column that is still byte-swapped.
    for c in df.select_dtypes(include=[np.number]).columns:
        arr = df[c].to_numpy()
        if not arr.dtype.isnative:
            # astype() to the native-order dtype preserves the values.
            # (ndarray.newbyteorder() was removed in NumPy 2.0, so it is
            # deliberately not used here.)
            df[c] = arr.astype(arr.dtype.newbyteorder("="))
    return df


def pick_mass_series(df_leph: pd.DataFrame) -> tuple[str, pd.Series]:
    """Choose the stellar-mass column to use from a LePhare table.

    Candidate column names are tried in a fixed priority order; the first one
    that exists and has more than 1000 values that parse as numbers wins.

    Args:
        df_leph: LePhare catalogue as a DataFrame.

    Returns:
        ``(column_name, values)`` where ``values`` is the column coerced to
        numeric (unparseable entries become NaN).

    Raises:
        SystemExit: If no candidate column passes the check.
    """
    candidates = ("MASS_MED", "MASS_BEST", "LOGMASS", "LOGMSTAR", "MASS_TOT", "MSTAR")
    for name in candidates:
        if name not in df_leph.columns:
            continue
        numeric = pd.to_numeric(df_leph[name], errors="coerce")
        n_valid = numeric.notna().sum()
        if n_valid > 1000:
            return name, numeric
    raise SystemExit("No usable mass column found in LePhare file.")


def as_log10(series: pd.Series) -> pd.Series:
    """Return stellar masses as log10(M), auto-detecting linear vs. log input.

    The unit is inferred from the median of the *positive* values only, so
    that negative sentinels (e.g. -99) cannot skew the detection even when
    they make up most of the column.

    Args:
        series: Mass values, either already log10(M) or linear.

    Returns:
        A Series aligned with the input holding log10(M). Entries that are
        non-finite, sentinels, or outside [6, 13] are set to NaN. If the input
        has no finite values at all it is returned (as numeric) unchanged.
    """
    s = pd.to_numeric(series, errors="coerce").replace([np.inf, -np.inf], np.nan)
    if s.notna().sum() == 0:
        return s

    # Both log (6..13) and linear (1e7..1e13) masses are positive, whereas
    # sentinels are not; fall back to all values only if nothing is positive.
    positive = s[s > 0]
    reference = s if positive.empty else positive
    q50 = np.nanquantile(reference, 0.5)

    if 1e7 <= q50 <= 1e13:
        # Linear masses: take the log of positive entries only (non-positive
        # entries become NaN instead of triggering log10 warnings / -inf).
        logm = np.log10(s.where(s > 0))
    else:
        # Median in 6..13 means the values are already log10(M). Any other
        # median is ambiguous; treat as log so the range mask below removes
        # whatever is implausible. Never apply log10 twice.
        logm = s

    # drop sentinels/implausible ranges (e.g., -99)
    logm = logm.mask(~np.isfinite(logm) | (logm < 6.0) | (logm > 13.0))
    return logm


def kpc_per_arcsec(z):
    """Return the physical scale at redshift *z* in kpc per arcsecond.

    The angular-diameter distance from the module-level ``COSMO`` (in kpc) is
    multiplied by the size of one arcsecond in radians (pi / 648000).

    Args:
        z: Scalar redshift.

    Returns:
        The scale in kpc/arcsec, or NaN when *z* is non-finite or not
        strictly positive.
    """
    usable = np.isfinite(z) and z > 0
    if usable:
        distance_kpc = COSMO.angular_diameter_distance(z).to_value("kpc")
        return distance_kpc * (np.pi / 648000.0)
    return np.nan


def _iter_tile_paths():
    """Yield tile catalogue paths from ``MULTIBAND_DIR`` in sorted order.

    Every ``*.fits`` file in the directory is considered; files whose
    lower-cased basename ends with one of the header/meta suffixes are skipped.
    """
    skipped_suffixes = (".fits.wc", ".wc.fits", ".l1.fits")
    pattern = os.path.join(MULTIBAND_DIR, "*.fits")
    for tile_path in sorted(glob.glob(pattern)):
        filename = os.path.basename(tile_path).lower()
        if not filename.endswith(skipped_suffixes):
            yield tile_path


def collect_sizes(want_ids: set[str]) -> pd.DataFrame:
    """Return DataFrame(ID_key, size_arcsec). Uses cache if present; else scans tiles once.

    Args:
        want_ids: Stripped string IDs of the lenses whose sizes are needed.

    Returns:
        One row per matched ID with its angular size in arcsec. When a scan is
        performed, an ID found in several tiles gets the median of its sizes.
        An empty frame (with both columns) is returned when nothing matched.

    Notes:
        The cache file only contains IDs that were requested during the scan
        that created it. Delete ``CACHE`` if the lens selection changes,
        otherwise newly requested IDs are silently absent.
    """
    if os.path.exists(CACHE):
        # Force string IDs so numeric-looking IDs are not reformatted
        # (e.g. leading zeros stripped) by CSV type inference.
        df = pd.read_csv(CACHE, dtype={"ID_key": str})
        df["ID_key"] = df["ID_key"].astype(str).str.strip()
        return df[df["ID_key"].isin(want_ids)][["ID_key", "size_arcsec"]].copy()

    rows = []
    for scanned, fp in enumerate(_iter_tile_paths(), start=1):
        if scanned % 100 == 0:
            print(f"[INFO] Scanned {scanned} tiles...")

        try:
            t = fits_to_df(fp)
        except (Exception, SystemExit) as e:
            # fits_to_df() signals "no table in this file" with SystemExit,
            # which `except Exception` alone would not catch; one bad tile
            # must not abort the whole scan.
            print(f"[WARN] Skipping {fp}: {e}")
            continue

        if "ID" not in t.columns:
            continue

        t["ID_key"] = t["ID"].astype(str).str.strip()
        t = t[t["ID_key"].isin(want_ids)]
        if t.empty:
            continue

        size_arcsec = None
        have_ab = {"A_WORLD", "B_WORLD"}.issubset(set(t.columns))
        if have_ab:
            # Geometric mean of the two axes, negatives clipped to zero.
            # The 3600 factor presumably converts degrees to arcsec — TODO confirm.
            a = pd.to_numeric(t["A_WORLD"], errors="coerce")
            b = pd.to_numeric(t["B_WORLD"], errors="coerce")
            size_arcsec = 3600.0 * np.sqrt(np.clip(a, 0, None) * np.clip(b, 0, None))
        elif "FLUX_RADIUS" in t.columns:
            # Fallback: scale FLUX_RADIUS by PIX_ARCSEC.
            fr = pd.to_numeric(t["FLUX_RADIUS"], errors="coerce")
            size_arcsec = fr * PIX_ARCSEC

        if size_arcsec is None:
            continue

        part = pd.DataFrame({"ID_key": t["ID_key"], "size_arcsec": size_arcsec})
        rows.append(part)

    if not rows:
        # Nothing matched: do not write a cache so the next run rescans.
        return pd.DataFrame(columns=["ID_key", "size_arcsec"])

    out = pd.concat(rows, ignore_index=True)
    out = out.groupby("ID_key", as_index=False)["size_arcsec"].median()
    out.to_csv(CACHE, index=False)
    return out


def _write_empty_output():
    """Write OUTCSV containing only the header row and report it."""
    with open(OUTCSV, "w") as f:
        f.write("lens_id,ra_deg,dec_deg,z_lens,R_G_kpc,Mstar_log10,R_G_bin,Mstar_bin\n")
    print(f"Wrote {OUTCSV} with 0 rows.")


def main():
    """Build the lens table and write it to OUTCSV.

    Steps: load the bright sample and LePhare tables, attach cleaned
    log10 stellar masses, look up angular sizes (from cache or tiles),
    convert sizes to kpc at each lens redshift, bin in size and mass, and
    write the lenses that fall inside both bin ranges.

    Raises:
        SystemExit: If an input directory/file or a required column is missing.
    """
    if not os.path.isdir(MULTIBAND_DIR):
        raise SystemExit(f"Missing multiband dir: {MULTIBAND_DIR}")
    for required in (BRIGHT, LEPH):
        if not os.path.isfile(required):
            raise SystemExit(f"Missing input file: {required}")

    # Bright + LePhare (endianness-safe)
    B = fits_to_df(BRIGHT)
    L = fits_to_df(LEPH)

    need = ["ID", "RAJ2000", "DECJ2000", "zphot_ANNz2"]
    missing = [c for c in need if c not in B.columns]
    if missing:
        raise SystemExit(f"Bright sample missing columns: {missing}")
    if "ID" not in L.columns:
        raise SystemExit("LePhare file missing columns: ['ID']")

    _, mass_raw = pick_mass_series(L)
    Mlog = as_log10(mass_raw)

    B2 = B[need].copy()
    B2["ID_key"] = B2["ID"].astype(str).str.strip()
    L2 = L[["ID"]].copy()
    L2["ID_key"] = L2["ID"].astype(str).str.strip()
    # Mlog shares L's index, so this assignment is row-aligned.
    L2["Mstar_log10"] = Mlog

    D = B2.merge(L2[["ID_key", "Mstar_log10"]], on="ID_key", how="inner")
    D = D[D["Mstar_log10"].notna()].copy()  # drop sentinels/out-of-range
    print(f"[INFO] Lenses after join & cleaned masses: {len(D):,}")

    # collect true sizes (cached if available)
    want_ids = set(D["ID_key"].unique())
    S = collect_sizes(want_ids)
    if S.empty:
        print("[WARN] No sizes matched from tiles (check MULTIBAND_DIR or ID formatting).")
        _write_empty_output()
        return

    # join sizes
    J = D.merge(S, on="ID_key", how="inner").copy()
    if J.empty:
        print("[WARN] Sizes joined 0 rows; check ID matching.")
        _write_empty_output()
        return

    # compute R_G (kpc)
    z = pd.to_numeric(J["zphot_ANNz2"], errors="coerce").to_numpy()
    # Explicit otypes: the output dtype must not depend on what the first
    # element happens to return.
    kpc_per_as = np.vectorize(kpc_per_arcsec, otypes=[float])(z)
    J["R_G_kpc"] = pd.to_numeric(J["size_arcsec"], errors="coerce") * kpc_per_as

    # quick sanity
    valid_rg = pd.to_numeric(J["R_G_kpc"], errors="coerce")
    valid_rg = valid_rg[np.isfinite(valid_rg)]
    if valid_rg.size:
        q = np.quantile(valid_rg, [0.05, 0.25, 0.5, 0.75, 0.95])
        print(f"R_G_kpc quantiles [5/25/50/75/95%]: [{q[0]:.2f} {q[1]:.2f} {q[2]:.2f} {q[3]:.2f} {q[4]:.2f}]")

    # binning (left-closed intervals; values outside the edges become NaN
    # and are dropped, as are lenses with invalid redshift or size)
    J["R_G_bin"] = pd.cut(J["R_G_kpc"], RG_EDGES, right=False)
    J["Mstar_bin"] = pd.cut(J["Mstar_log10"], MS_EDGES, right=False)
    K = J.dropna(subset=["R_G_bin", "Mstar_bin"]).copy()

    out = pd.DataFrame({
        "lens_id": K["ID_key"],
        "ra_deg": pd.to_numeric(K["RAJ2000"], errors="coerce"),
        "dec_deg": pd.to_numeric(K["DECJ2000"], errors="coerce"),
        "z_lens": pd.to_numeric(K["zphot_ANNz2"], errors="coerce"),
        "R_G_kpc": pd.to_numeric(K["R_G_kpc"], errors="coerce"),
        "Mstar_log10": pd.to_numeric(K["Mstar_log10"], errors="coerce"),
        "R_G_bin": K["R_G_bin"].astype(str),
        "Mstar_bin": K["Mstar_bin"].astype(str),
    })
    out.to_csv(OUTCSV, index=False)

    # per-bin counts (best-effort diagnostic; the CSV is already written)
    try:
        piv = out.assign(
            R_G_bin_cat=pd.Categorical(out["R_G_bin"], ordered=True),
            Mstar_bin_cat=pd.Categorical(out["Mstar_bin"], ordered=True),
        ).pivot_table(index="R_G_bin_cat", columns="Mstar_bin_cat",
                      values="lens_id", aggfunc="count", fill_value=0,
                      observed=False)
        print(piv)
    except Exception as e:
        # Report instead of silently hiding a broken diagnostic.
        print(f"[WARN] Could not build per-bin count table: {e}")

    print(f"Wrote {OUTCSV} with {len(out):,} rows.")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
